import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
wdir = os.path.abspath('.')
sys.path.append(wdir)
from data import prior_box_config as cfg
from layers.box_utils import match, log_sum_exp, refine_match

class RefineDetMuliBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching  ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched  'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, theta=0.01, use_ARM=False):
        super(RefineDetMuliBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = cfg['variance']
        self.theta = theta
        self.use_ARM = use_ARM  # compute the ODM loss when True, the ARM loss when False

    def forward(self, predictions, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing ARM/ODM loc preds,
            ARM/ODM conf preds, and prior boxes from the RefineDet net.
                conf shape: torch.Size(batch_size, num_priors, num_classes)
                loc shape: torch.Size(batch_size, num_priors, 4)
                priors shape: torch.Size(num_priors, 4)

            targets (list): Ground truth boxes and labels for a batch:
                batch_size tensors of shape [num_objs, 5], where the last
                column is the class label.
        """
        arm_loc_data, arm_conf_data, odm_loc_data, odm_conf_data, priors = predictions
        if self.use_ARM:
            loc_data, conf_data = odm_loc_data, odm_conf_data
        else:
            loc_data, conf_data = arm_loc_data, arm_conf_data
        
        batch_size = loc_data.size(0)
        # Truncate priors so their count matches the number of boxes the
        # ARM/ODM actually predicts, keeping all downstream shapes consistent.
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(batch_size, num_priors, 4)
        conf_t = torch.LongTensor(batch_size, num_priors)
        for idx in range(batch_size):
            truths = targets[idx][:, :-1].detach()  # box coordinates (drop the label column)
            labels = targets[idx][:, -1].detach()   # class labels (last column only)
            if num_classes == 2:
                # ARM loss is binary (object vs. background), so every
                # ground-truth box counts as a positive.
                labels = labels >= 0
            defaults = priors.detach()
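            # For the ODM, refine_match additionally receives the ARM's
            # regressed offsets, so ground truth is matched against the
            # refined anchors rather than the original priors.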
            if self.use_ARM:
                refine_match(self.threshold, truths, defaults, self.variance, labels,
                            loc_t, conf_t, idx, arm_loc_data[idx].detach())
            else:
                refine_match(self.threshold, truths, defaults, self.variance, labels,
                            loc_t, conf_t, idx)
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()

        loc_t.requires_grad = False
        conf_t.requires_grad = False

        if self.use_ARM:
            # Anchor filtering: discard anchors that the ARM already rejects
            # as background with high confidence (objectness <= theta).
            P = F.softmax(arm_conf_data, dim=2)
            arm_conf_temp = P[:, :, 1]
            object_score_index = arm_conf_temp <= self.theta
            pos = conf_t > 0
            pos[object_score_index] = 0
        else:
            pos = conf_t > 0
        

        # Localization loss (Smooth L1), computed over positive anchors only.
        # pos_idx has shape [batch, num_priors, 4] after the expand.
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)  # predicted offsets
        loc_t = loc_t[pos_idx].view(-1, 4)     # encoded regression targets
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
        # Per-prior classification loss, used only to rank negatives for mining.
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
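        # Note: log_sum_exp(x) - x[target] == -log(softmax(x)[target]), i.e.
        # the per-prior cross-entropy, computed without materializing softmax.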

        # Hard negative mining: keep only the hardest negative examples.
        loss_c[pos.view(-1, 1)] = 0  # zero out positives so only negatives are ranked
        loss_c = loss_c.view(batch_size, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
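        # The double argsort turns losses into ranks: e.g. for losses
        # [0.2, 0.9, 0.5] the descending sort gives loss_idx = [1, 2, 0], and
        # sorting that gives idx_rank = [2, 0, 1], so the largest loss ranks 0.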
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1) - 1)  # cap negatives at num_priors - 1
        neg = idx_rank < num_neg.expand_as(idx_rank)    # True for the num_neg hardest negatives

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx | neg_idx)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos | neg)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = num_pos.detach().sum().float().clamp(min=1)  # guard against batches with no positives

        loss_l /= N
        loss_c /= N

        return loss_l, loss_c


if __name__ == '__main__':
    # Quick sanity check of the slicing used in forward(): split a fake
    # targets tensor into its box columns and its label column.
    t = torch.rand((2, 3, 4))
    print(t, t.shape)
    t0 = t[0][:, :-1]
    print(t0, t0.shape)
    print(t0.detach())
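
    # --- Minimal smoke test (a sketch, not part of the original module). ---
    # It assumes the repo's layers.box_utils (refine_match, log_sum_exp) and
    # prior_box_config imported above are available. All sizes below (B, P, C)
    # and the single ground-truth box are arbitrary illustration values, not
    # settings from the original training setup.
    torch.manual_seed(0)
    B, P, C = 2, 6, 4  # batch size, num_priors, num_classes
    arm_loc = torch.randn(B, P, 4)
    arm_conf = torch.randn(B, P, 2)
    odm_loc = torch.randn(B, P, 4)
    odm_conf = torch.randn(B, P, C)
    priors = torch.rand(P, 4)  # (cx, cy, w, h), normalized coordinates
    predictions = (arm_loc, arm_conf, odm_loc, odm_conf, priors)
    # One ground-truth box per image: (xmin, ymin, xmax, ymax, label).
    targets = [torch.tensor([[0.1, 0.1, 0.6, 0.6, 1.0]]) for _ in range(B)]
    # ARM criterion is binary (num_classes=2); ODM criterion uses use_ARM=True.
    arm_criterion = RefineDetMuliBoxLoss(2, 0.5, True, 0, True, 3, 0.5,
                                         False, use_gpu=False)
    odm_criterion = RefineDetMuliBoxLoss(C, 0.5, True, 0, True, 3, 0.5,
                                         False, use_gpu=False, use_ARM=True)
    arm_loss_l, arm_loss_c = arm_criterion(predictions, targets)
    odm_loss_l, odm_loss_c = odm_criterion(predictions, targets)
    print('ARM loss (loc, conf):', arm_loss_l.item(), arm_loss_c.item())
    print('ODM loss (loc, conf):', odm_loss_l.item(), odm_loss_c.item())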